# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""X3D stem helper."""

import mindspore.nn as nn


class VideoModelStem(nn.Cell):
    """
    Stem wrapper used at the input of the X3D video model.

    Every argument is given as a per-pathway list, but this class builds a
    single ``X3DStem`` from the entries at index 0 and stores it as
    ``self.stem``. Entries at other indices are ignored. ``num_pathways``
    records ``len(dim_in)`` and is not used elsewhere in this class.

    Args:
        dim_in (list): input channel dimensions, one per pathway. Only
            ``dim_in[0]`` is used.
        dim_out (list): output channel dimensions of the stem convolution,
            one per pathway. Only ``dim_out[0]`` is used.
        kernel (list): kernel sizes, one ``[temporal, height, width]`` list
            per pathway. Only ``kernel[0]`` is used.
        stride (list): strides, one ``[temporal, height, width]`` list per
            pathway. Only ``stride[0]`` is used.
        padding (list): paddings, one ``[temporal, height, width]`` list per
            pathway. Only ``padding[0]`` is used.
        inplace_relu (bool): stored on the instance and forwarded to
            ``X3DStem``.
        eps (float): epsilon forwarded to the normalization layer.
        bn_mmt (float): momentum forwarded to the normalization layer.
        norm_module (nn.Cell): class (or factory) of the normalization
            layer. The default is nn.BatchNorm3d.

    Returns:
        Tensor, the output of the wrapped ``X3DStem``.

    Examples:
        >>> videoModelStem = VideoModelStem(dim_in=[3], dim_out=[24], kernel=[[5, 3, 3]], stride=[[1, 2, 2]],
        >>>             padding=[[2, 1, 1]], inplace_relu=True, eps=1e-5, bn_mmt=0.1, norm_module=nn.BatchNorm3d)
    """

    def __init__(
            self,
            dim_in,
            dim_out,
            kernel,
            stride,
            padding,
            inplace_relu=True,
            eps=1e-5,
            bn_mmt=0.1,
            norm_module=nn.BatchNorm3d,
    ):
        super().__init__()
        # Keep the raw configuration on the instance for later inspection.
        self.num_pathways = len(dim_in)
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.inplace_relu = inplace_relu
        self.eps = eps
        self.bn_mmt = bn_mmt

        # Only the first pathway is materialised as a stem.
        first = 0
        self.stem = X3DStem(
            dim_in=dim_in[first],
            dim_out=dim_out[first],
            kernel=kernel[first],
            stride=stride[first],
            padding=padding[first],
            inplace_relu=inplace_relu,
            eps=eps,
            bn_mmt=bn_mmt,
            norm_module=norm_module,
        )

    def construct(self, x):
        """Run the input tensor through the single stem and return the result."""
        return self.stem(x)


class X3DStem(nn.Cell):
    """
    X3D's 3D stem module.

    Applies, in order: a spatial convolution with kernel ``(1, kH, kW)``, a
    temporal convolution with kernel ``(kT, 1, 1)``, a normalization layer
    and a ReLU. No pooling layer is built or applied by this class.

    Args:
        dim_in (int): the channel dimension of the input. Normally 3 is used
            for rgb input, and 2 or 3 is used for optical flow input.
        dim_out (int): the output channel dimension of both convolutions and
            the number of features of the normalization layer.
        kernel (list): kernel size as ``[temporal, height, width]``. The
            height/width entries configure the spatial convolution, the
            temporal entry configures the temporal convolution.
        stride (list): stride as ``[temporal, height, width]``, split between
            the two convolutions in the same way as ``kernel``.
        padding (list): padding as ``[temporal, height, width]``. Each entry
            is applied symmetrically (same amount on both sides of its axis).
        inplace_relu (bool): stored on the instance for interface
            compatibility but not used; ``nn.ReLU()`` is created without
            arguments.
        eps (float): epsilon passed to the normalization layer as ``eps``.
        bn_mmt (float): value passed unchanged to the normalization layer as
            ``momentum``.
            NOTE(review): momentum conventions differ between frameworks
            (some use ``1 - momentum``); confirm that the value supplied
            matches the convention of ``norm_module``.
        norm_module (nn.Cell): class (or factory) of the normalization layer.
            It is called with the keyword arguments ``num_features``, ``eps``
            and ``momentum``. The default is nn.BatchNorm3d.

    Returns:
        Tensor, the output of the final ReLU.

    Examples:
        >>> x3dStem = X3DStem(dim_in=3, dim_out=24, kernel=[5, 3, 3], stride=[1, 2, 2], padding=[2, 1, 1],
        >>>                 inplace_relu=True, eps=1e-5, bn_mmt=0.1, norm_module=nn.BatchNorm3d)
    """

    def __init__(self,
                 dim_in,
                 dim_out,
                 kernel,
                 stride,
                 padding,
                 inplace_relu=True,
                 eps=1e-5,
                 bn_mmt=0.1,
                 norm_module=nn.BatchNorm3d,
                 ):
        super().__init__()
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.inplace_relu = inplace_relu
        self.eps = eps
        self.bn_mmt = bn_mmt
        # Construct the stem layer.
        self._construct_stem(dim_in, dim_out, norm_module)

    def _construct_stem(self, dim_in, dim_out, norm_module):
        """
        Create the sub-cells of the stem.

        Reads ``self.kernel``, ``self.stride``, ``self.padding``, ``self.eps``
        and ``self.bn_mmt``, and sets ``self.conv_xy``, ``self.conv``,
        ``self.bn`` and ``self.relu`` (in that order).

        Args:
            dim_in (int): input channels of the spatial convolution.
            dim_out (int): output channels of both convolutions and number of
                features of the normalization layer.
            norm_module (nn.Cell): class (or factory) of the normalization
                layer.
        """
        # Spatial convolution: temporal extent and stride are fixed to 1 and
        # the temporal axis gets no padding. With pad_mode='pad', MindSpore's
        # Conv3d takes six padding values ordered as
        # (head, tail, top, bottom, left, right).
        self.conv_xy = nn.Conv3d(
            dim_in,
            dim_out,
            kernel_size=(1, self.kernel[1], self.kernel[2]),
            stride=(1, self.stride[1], self.stride[2]),
            pad_mode='pad',
            padding=(
                0,
                0,
                self.padding[1],
                self.padding[1],
                self.padding[2],
                self.padding[2]),
            has_bias=False,
        )
        # Temporal convolution: spatial extent and stride are fixed to 1 and
        # only the temporal axis is padded.
        self.conv = nn.Conv3d(
            dim_out,
            dim_out,
            kernel_size=(self.kernel[0], 1, 1),
            stride=(self.stride[0], 1, 1),
            pad_mode='pad',
            padding=(self.padding[0], self.padding[0], 0, 0, 0, 0),
            has_bias=False,
            # NOTE: the intended depthwise setting is group=dim_out; group is
            # pinned to 1 because MindSpore only supported group=1 for this
            # layer when the code was written, so this is a full convolution.
            group=1
        )
        # Normalization over the dim_out channels.
        self.bn = norm_module(
            num_features=dim_out, eps=self.eps, momentum=self.bn_mmt
        )
        # Activation; inplace_relu is intentionally not forwarded.
        self.relu = nn.ReLU()

    def construct(self, x):
        """Apply spatial conv, temporal conv, normalization and ReLU to ``x``."""
        x = self.conv_xy(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
